{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "b518b04cbfe0"
      },
      "source": [
        "##### Copyright 2020 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "906e07f6e562"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "a81c428fc2d3"
      },
      "source": [
        "# 転移学習とファインチューニング"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3e5a59f0aefd"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td><a target=\"_blank\" href=\"https://www.tensorflow.org/guide/keras/transfer_learning\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\">TensorFlow.org で実行</a></td>\n",
        "  <td>     <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs-l10n/blob/master/site/ja/guide/keras/transfer_learning.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\">\tGoogle Colabで実行</a>\n",
        "</td>\n",
        "  <td><a target=\"_blank\" href=\"https://github.com/tensorflow/docs-l10n/blob/master/site/ja/guide/keras/transfer_learning.ipynb\">     <img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\">     GitHubでソースを表示</a></td>\n",
        "  <td>     <a href=\"https://storage.googleapis.com/tensorflow_docs/docs-l10n/site/ja/guide/keras/transfer_learning.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\">ノートブックをダウンロード</a>\n",
        "</td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8d4ac441b1fc"
      },
      "source": [
        "## セットアップ"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9a7e9b92f963"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "from tensorflow import keras"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "00d4c41cfe2f"
      },
      "source": [
        "## はじめに\n",
        "\n",
        "ある問題で学習した特徴量を取り入れ、それを新しい類似した問題に利用する方法を**転移学習**と呼びます。たとえば、アライグマの識別を学習したモデルの特徴量がある場合、それを使用してタヌキの識別を学習するモデルに取り組むことができます。\n",
        "\n",
        "通常、転移学習はデータセットのデータが少なすぎてフルスケールモデルをゼロからトレーニングできないようなタスクで行われます。\n",
        "\n",
        "ディープラーニングの文脈では、転移学習は次のワークフローで行われるのが最も一般的です。\n",
        "\n",
        "1. 以前にトレーニングされたモデルからレイヤーを取得します。\n",
        "2. 以降のトレーニングラウンドでそれらのレイヤーに含まれる情報が破損しないように凍結します。\n",
        "3. 凍結したレイヤーの上にトレーニング対象のレイヤーを新たに追加します。これらのレイヤーは古い特徴量を新しいデータセットの予測に変換することを学習します。\n",
        "4. データセットで新しいレイヤーをトレーニングします。\n",
        "\n",
        "最後に任意で**ファインチューニング**を実施できます。ファインチューニングでは、上記で取得したモデル全体（または一部）を解凍し、新しいデータに対して非常に低い学習率で再トレーニングします。これを実施すると、事前トレーニング済みの特徴量を徐々に新しいデータに適応させ、意味のある改善を得られることがあります。\n",
        "\n",
        "まずは、転移学習とファインチューニングのほとんどのワークフローの基礎である、Keras の `trainable` API について詳しく見てみましょう。\n",
        "\n",
        "次に、一般的なワークフローを説明します。ImageNet データセットで事前にトレーニングされたモデルを取得してそれを Kaggle の「犬と猫」分類データセットで再トレーニングしてみましょう。\n",
        "\n",
        "これは、[Deep Learning with Python](https://www.manning.com/books/deep-learning-with-python) および 2016 年のブログ記事[「少ないデータで強力な画像分類モデルを構築する」](https://blog.keras.io/building-powerful-image-classification-models-using-very-little-data.html)を基にしています。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fbf8630c325b"
      },
      "source": [
        "## レイヤーの凍結: `trainable` 属性を理解する\n",
        "\n",
        "レイヤーとモデルには 3 つの重み属性があります。\n",
        "\n",
        "- `weights` は、レイヤーのすべての重み変数のリストです。\n",
        "- `trainable_weights` は、トレーニング中の損失を最小限に抑えるために（勾配降下を介して）更新を意図した重みのリストです。\n",
        "- `non_trainable_weights` は、トレーニングを意図していない重みのリストです。通常は、フォワードパスの間にモデルによって更新されます。\n",
        "\n",
        "**例: `Dense` レイヤーに、トレーニング対象の重みが 2 つ（カーネルとバイアス）がある**"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "407deab1855e"
      },
      "outputs": [],
      "source": [
        "layer = keras.layers.Dense(3)\n",
        "layer.build((None, 4))  # Create the weights\n",
        "\n",
        "print(\"weights:\", len(layer.weights))\n",
        "print(\"trainable_weights:\", len(layer.trainable_weights))\n",
        "print(\"non_trainable_weights:\", len(layer.non_trainable_weights))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "79fcb9cc960d"
      },
      "source": [
        "一般的に、重みはすべてトレーニング対象の重みです。トレーニング対象外の重みレイヤーを持つ組み込みレイヤーは、`BatchNormalization` レイヤーしかありません。これはトレーニング対象外の重みを使用して、トレーニング中の入力の平均と分散を追跡します。独自のカスタムレイヤーでトレーニング対象外の重みを使用する方法については、[レイヤーの新規作成ガイド](https://keras.io/guides/making_new_layers_and_models_via_subclassing/)をご覧ください。\n",
        "\n",
        "**例: `BatchNormalization` レイヤーに 、トレーニング対象の重みとトレーニング対象外の重みが 2 つずつある**"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fbc87a09bc3c"
      },
      "outputs": [],
      "source": [
        "layer = keras.layers.BatchNormalization()\n",
        "layer.build((None, 4))  # Create the weights\n",
        "\n",
        "print(\"weights:\", len(layer.weights))\n",
        "print(\"trainable_weights:\", len(layer.trainable_weights))\n",
        "print(\"non_trainable_weights:\", len(layer.non_trainable_weights))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cddcdbf2bd5b"
      },
      "source": [
        "レイヤーとモデルには、ブール属性の `trainable` もあり、その値を変更することができます。`layer.trainable` を `False` に設定すると、すべてのレイヤーの重みがトレーニング対象からトレーニング対象外に移動されます。これはレイヤーの「凍結」と呼ばれるもので、凍結されたレイヤーの状態はトレーニング中に更新されません（`fit()` でトレーニングする場合、または`trainable_weights` に依存して勾配の更新を適用するカスタムループでトレーニングする場合）。\n",
        "\n",
        "**例: `trainable` を `False`に設定する**"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "51bbc5d12742"
      },
      "outputs": [],
      "source": [
        "layer = keras.layers.Dense(3)\n",
        "layer.build((None, 4))  # Create the weights\n",
        "layer.trainable = False  # Freeze the layer\n",
        "\n",
        "print(\"weights:\", len(layer.weights))\n",
        "print(\"trainable_weights:\", len(layer.trainable_weights))\n",
        "print(\"non_trainable_weights:\", len(layer.non_trainable_weights))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "32904f9a58db"
      },
      "source": [
        "トレーニング対象の重みがトレーニング対象外の重みになると、その値はトレーニング中に更新されなくなります。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3c26c27a8291"
      },
      "outputs": [],
      "source": [
        "# Make a model with 2 layers\n",
        "layer1 = keras.layers.Dense(3, activation=\"relu\")\n",
        "layer2 = keras.layers.Dense(3, activation=\"sigmoid\")\n",
        "model = keras.Sequential([keras.Input(shape=(3,)), layer1, layer2])\n",
        "\n",
        "# Freeze the first layer\n",
        "layer1.trainable = False\n",
        "\n",
        "# Keep a copy of the weights of layer1 for later reference\n",
        "initial_layer1_weights_values = layer1.get_weights()\n",
        "\n",
        "# Train the model\n",
        "model.compile(optimizer=\"adam\", loss=\"mse\")\n",
        "model.fit(np.random.random((2, 3)), np.random.random((2, 3)))\n",
        "\n",
        "# Check that the weights of layer1 have not changed during training\n",
        "final_layer1_weights_values = layer1.get_weights()\n",
        "np.testing.assert_allclose(\n",
        "    initial_layer1_weights_values[0], final_layer1_weights_values[0]\n",
        ")\n",
        "np.testing.assert_allclose(\n",
        "    initial_layer1_weights_values[1], final_layer1_weights_values[1]\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "412d7d659aa1"
      },
      "source": [
        "`layer.trainable` 属性を `layer.__call__()` の引数 `training` と混同しないようにしてください（後者は、レイヤーがフォワードパスを推論モードで実行するか、トレーニングモードで実行するかを制御します）。詳細については、[Keras よくある質問](https://keras.io/getting_started/faq/#whats-the-difference-between-the-training-argument-in-call-and-the-trainable-attribute)をご覧ください。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "e6ccd3c7ab1a"
      },
      "source": [
        "## `trainable` 属性を再帰的に設定する\n",
        "\n",
        "モデルや、サブレイヤーのあるレイヤーで `trainable = False` を設定すると、すべての子レイヤーもトレーニング対象外になります。\n",
        "\n",
        "**例:**"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4235d0c69821"
      },
      "outputs": [],
      "source": [
        "inner_model = keras.Sequential(\n",
        "    [\n",
        "        keras.Input(shape=(3,)),\n",
        "        keras.layers.Dense(3, activation=\"relu\"),\n",
        "        keras.layers.Dense(3, activation=\"relu\"),\n",
        "    ]\n",
        ")\n",
        "\n",
        "model = keras.Sequential(\n",
        "    [keras.Input(shape=(3,)), inner_model, keras.layers.Dense(3, activation=\"sigmoid\"),]\n",
        ")\n",
        "\n",
        "model.trainable = False  # Freeze the outer model\n",
        "\n",
        "assert inner_model.trainable == False  # All layers in `model` are now frozen\n",
        "assert inner_model.layers[0].trainable == False  # `trainable` is propagated recursively"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "61535ba76727"
      },
      "source": [
        "## 典型的な転移学習のワークフロー\n",
        "\n",
        "ここでは、典型的な転移学習のワークフローを Keras に実装する方法を示します。\n",
        "\n",
        "1. ベースモデルをインスタンス化し、それに事前トレーニング済みの重みを読み込みます。\n",
        "2. `trainable = False` を設定して、ベースモデルのすべてのレイヤーを凍結します。\n",
        "3. ベースモデルの 1 つ以上のレイヤーの出力上に新しいモデルを作成します。\n",
        "4. 新しいデータセットで新しいモデルをトレーニングします。\n",
        "\n",
        "これに代わる、より軽量なワークフローとして、以下のようなものも考えられます。\n",
        "\n",
        "1. ベースモデルをインスタンス化し、それに事前トレーニング済みの重みを読み込みます。\n",
        "2. 新しいデータセットを実行して、ベースモデルの 1 つ以上のレイヤーの出力を記録します。**特徴量抽出**と呼ばれる作業です。\n",
        "3. その出力を新しい小さなモデルの入力データとして使用します。\n",
        "\n",
        "この 2 番目のワークフローには、トレーニングのエポックごとに 1 回ではなく、データに対して 1 回だけベースモデルを実行するため、はるかに高速で安価になるというメリットがあります。\n",
        "\n",
        "ただし、トレーニング中に新しいモデルの入力データを動的に変更することができないという問題があります。これはデータを拡張する際などに必要なことです。転移学習は通常、フルスケールでモデルを新規にトレーニングするには新しいデータセットのデータ量が少なすぎる場合に使用しますが、そのような場合、データの拡張が非常に重要になります。そこで以降では、1 番目のワークフローに焦点を当てます。\n",
        "\n",
        "1 番目のワークフローは、Keras では以下のようになります。\n",
        "\n",
        "まず最初に、事前トレーニング済みの重みを使用してベースモデルをインスタンス化します。\n",
        "\n",
        "```python\n",
        "base_model = keras.applications.Xception(\n",
        "    weights='imagenet',  # Load weights pre-trained on ImageNet.\n",
        "    input_shape=(150, 150, 3),\n",
        "    include_top=False)  # Do not include the ImageNet classifier at the top.\n",
        "```\n",
        "\n",
        "次に、ベースモデルを凍結します。\n",
        "\n",
        "```python\n",
        "base_model.trainable = False\n",
        "```\n",
        "\n",
        "その上に新しいモデルを作成します。\n",
        "\n",
        "```python\n",
        "inputs = keras.Input(shape=(150, 150, 3))\n",
        "# We make sure that the base_model is running in inference mode here,\n",
        "# by passing `training=False`. This is important for fine-tuning, as you will\n",
        "# learn in a few paragraphs.\n",
        "x = base_model(inputs, training=False)\n",
        "# Convert features of shape `base_model.output_shape[1:]` to vectors\n",
        "x = keras.layers.GlobalAveragePooling2D()(x)\n",
        "# A Dense classifier with a single unit (binary classification)\n",
        "outputs = keras.layers.Dense(1)(x)\n",
        "model = keras.Model(inputs, outputs)\n",
        "```\n",
        "\n",
        "新しいデータでモデルをトレーニングします。\n",
        "\n",
        "```python\n",
        "model.compile(optimizer=keras.optimizers.Adam(),\n",
        "              loss=keras.losses.BinaryCrossentropy(from_logits=True),\n",
        "              metrics=[keras.metrics.BinaryAccuracy()])\n",
        "model.fit(new_dataset, epochs=20, callbacks=..., validation_data=...)\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "736c99aea690"
      },
      "source": [
        "## ファインチューニング\n",
        "\n",
        "モデルが新しいデータで収束したら、ベースモデルのすべてまたは一部を解凍して、非常に低い学習率でエンドツーエンドでモデル全体を再トレーニングすることができます。\n",
        "\n",
        "これは任意に行える最後のステップではありますが、段階的な改善を期待することができます。ただし、すぐに過適合になる可能性もあることに注意してください。\n",
        "\n",
        "このステップは、凍結レイヤーのあるモデルが収束するまでトレーニングされた*後に*のみ行うことが重要です。ランダムに初期化されたトレーニング対象レイヤーと事前にトレーニングされた特徴量を持つトレーニング対象レイヤーを混ぜると、トレーニング中に、ランダムに初期化されたレイヤーによって非常に大きな勾配の更新が発生し、事前にトレーニングされた特徴量が破損してしまうことになります。\n",
        "\n",
        "また、この段階では学習率が非常に低いことも重要です。1 回目のトレーニングよりもはるかに大きなモデルを、非常に小さなデータセットでトレーニングするからです。その結果、大量の重みの更新を適用すると、あっという間に過適合が起きてしまう危険性があります。ここでは、事前トレーニング済みの重みを段階的に適応し直します。\n",
        "\n",
        "ベースモデル全体のファインチューニングを実装するには、以下のようにします。\n",
        "\n",
        "```python\n",
        "# Unfreeze the base model\n",
        "base_model.trainable = True\n",
        "\n",
        "# It's important to recompile your model after you make any changes\n",
        "# to the `trainable` attribute of any inner layer, so that your changes\n",
        "# are take into account\n",
        "model.compile(optimizer=keras.optimizers.Adam(1e-5),  # Very low learning rate\n",
        "              loss=keras.losses.BinaryCrossentropy(from_logits=True),\n",
        "              metrics=[keras.metrics.BinaryAccuracy()])\n",
        "\n",
        "# Train end-to-end. Be careful to stop before you overfit!\n",
        "model.fit(new_dataset, epochs=10, callbacks=..., validation_data=...)\n",
        "```\n",
        "\n",
        "**`compile()` および `trainable` に関する重要な注意事項**\n",
        "\n",
        "モデルで `compile()` を呼び出すと、そのモデルの動作が「凍結」されます。これは、モデルがコンパイルされたときの `trainable` 属性の値は、`compile` が再び呼び出されるまで、そのモデルの寿命が続く限り保持されるということです。したがって、`trainable` の値を変更した場合には、その内容が考慮されるように必ずモデルでもう一度 `compile()` を呼び出してください。\n",
        "\n",
        "**`BatchNormalization` レイヤーに関する重要な注意事項**\n",
        "\n",
        "多くの画像モデルには `BatchNormalization` レイヤーが含まれています。このレイヤーは、あらゆる点において特殊なケースです。ここにいくつかの注意点を示します。\n",
        "\n",
        "- `BatchNormalization` には、トレーニング中に更新されるトレーニング対象外の重みが 2 つ含まれています。これらは入力の平均と分散を追跡する変数です。\n",
        "- `bn_layer.trainable = False` を設定すると、`BatchNormalization` レイヤーは推論モードで実行されるため、その平均と分散の統計は更新されません。[重みのトレーナビリティと推論/トレーニングモードは 2 つの直交する概念](https://keras.io/getting_started/faq/#whats-the-difference-between-the-training-argument-in-call-and-the-trainable-attribute)であるため、これは一般的には他のレイヤーには当てはまりませんが、`BatchNormalization` レイヤーの場合は、この 2 つは関連しています。\n",
        "- ファインチューニングを行うために  `BatchNormalization` レイヤーを含むモデルを解凍する場合、ベースモデルを呼び出す際に `training = False` を渡して `BatchNormalization` レイヤーを推論モードにしておく必要があります。推論モードになっていない場合、トレーニング対象外の重みに適用された更新によって、モデルが学習したものが突然破壊されてしまいます。\n",
        "\n",
        "このガイドの最後にあるエンドツーエンドの例で、このパターンの動作を確認することができます。\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bce9ffc4e290"
      },
      "source": [
        "## カスタムトレーニングループで転移学習とファインチューニングをする\n",
        "\n",
        "`fit()` の代わりに独自の低レベルのトレーニングループを使用している場合でも、ワークフローは基本的に同じです。勾配の更新を適用する際には、`model.trainable_weights` のリストのみを考慮するように注意する必要があります。\n",
        "\n",
        "```python\n",
        "# Create base model\n",
        "base_model = keras.applications.Xception(\n",
        "    weights='imagenet',\n",
        "    input_shape=(150, 150, 3),\n",
        "    include_top=False)\n",
        "# Freeze base model\n",
        "base_model.trainable = False\n",
        "\n",
        "# Create new model on top.\n",
        "inputs = keras.Input(shape=(150, 150, 3))\n",
        "x = base_model(inputs, training=False)\n",
        "x = keras.layers.GlobalAveragePooling2D()(x)\n",
        "outputs = keras.layers.Dense(1)(x)\n",
        "model = keras.Model(inputs, outputs)\n",
        "\n",
        "loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)\n",
        "optimizer = keras.optimizers.Adam()\n",
        "\n",
        "# Iterate over the batches of a dataset.\n",
        "for inputs, targets in new_dataset:\n",
        "    # Open a GradientTape.\n",
        "    with tf.GradientTape() as tape:\n",
        "        # Forward pass.\n",
        "        predictions = model(inputs)\n",
        "        # Compute the loss value for this batch.\n",
        "        loss_value = loss_fn(targets, predictions)\n",
        "\n",
        "    # Get gradients of loss wrt the *trainable* weights.\n",
        "    gradients = tape.gradient(loss_value, model.trainable_weights)\n",
        "    # Update the weights of the model.\n",
        "    optimizer.apply_gradients(zip(gradients, model.trainable_weights))\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4e63ba34ce1c"
      },
      "source": [
        "ファインチューニングの場合も同様です。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "852447087ba9"
      },
      "source": [
        "## エンドツーエンドの例: 猫と犬データセットの画像分類モデルのファインチューニング\n",
        "\n",
        "この概念を固めるために、具体的なエンドツーエンドの転移学習とファインチューニングの例を見てみましょう。ImageNet で事前トレーニングされた Xception モデルを読み込み、Kaggleの 「犬と猫」分類データセットで使用します。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ba75835e0de6"
      },
      "source": [
        "### データを取得する\n",
        "\n",
        "まず、TFDS を使用して犬と猫のデータセットを取得してみましょう。独自のデータセットをお持ちの場合は、`tf.keras.preprocessing.image_dataset_from_directory` ユーティリティを使用して、クラス固有のフォルダにファイル作成されたディスク上の画像集合から類似のラベル付きデータセットオブジェクトを生成することもできます。\n",
        "\n",
        "転移学習は、非常に小さなデータセットを扱う場合に最も有用です。データセットを小さく保つために、元のトレーニングデータ（画像 25,000 枚）の 40 ％をトレーニングに、10 ％を検証に、10 ％をテストに使用します。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1a99f56934f7"
      },
      "outputs": [],
      "source": [
        "import tensorflow_datasets as tfds\n",
        "\n",
        "tfds.disable_progress_bar()\n",
        "\n",
        "train_ds, validation_ds, test_ds = tfds.load(\n",
        "    \"cats_vs_dogs\",\n",
        "    # Reserve 10% for validation and 10% for test\n",
        "    split=[\"train[:40%]\", \"train[40%:50%]\", \"train[50%:60%]\"],\n",
        "    as_supervised=True,  # Include labels\n",
        ")\n",
        "\n",
        "print(\"Number of training samples: %d\" % tf.data.experimental.cardinality(train_ds))\n",
        "print(\n",
        "    \"Number of validation samples: %d\" % tf.data.experimental.cardinality(validation_ds)\n",
        ")\n",
        "print(\"Number of test samples: %d\" % tf.data.experimental.cardinality(test_ds))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9db548603642"
      },
      "source": [
        "ここにトレーニングデータセットの最初の 9 枚の画像があります。ご覧の通り、サイズはバラバラです。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "00c8cbd1de88"
      },
      "outputs": [],
      "source": [
        "import matplotlib.pyplot as plt\n",
        "\n",
        "plt.figure(figsize=(10, 10))\n",
        "for i, (image, label) in enumerate(train_ds.take(9)):\n",
        "    ax = plt.subplot(3, 3, i + 1)\n",
        "    plt.imshow(image)\n",
        "    plt.title(int(label))\n",
        "    plt.axis(\"off\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "168c4a10c072"
      },
      "source": [
        "また、ラベル 1 が「犬」、ラベル 0 が「猫」であることもわかります。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "f749203cd740"
      },
      "source": [
        "### データを標準化する\n",
        "\n",
        "生の画像には様々なサイズがあります。さらに、各ピクセルは 0 ～ 255 の 3 つの整数値（RGB レベル値）で構成されています。これは、ニューラルネットワークへの供給には適しません。次の 2 つを行う必要があります。\n",
        "\n",
        "- 標準化して画像サイズを固定します。 150x150 を選択します。\n",
        "- ピクセル値を -1 〜 1 に正規化します。これはモデル自体の一部として`Normalization`レイヤーを使用して行います。\n",
        "\n",
        "一般的に、すでに処理済みのデータを使用するモデルとは対照的に、入力に生のデータを使用するモデルを開発するのは良い実践です。その理由は、モデルが前処理されたデータを期待していると、モデルをエクスポートして他の場所（ウェブブラウザやモバイルアプリ）で使用する際には、まったく同じ前処理パイプラインを常に再実装する必要が生じるからです。これはすぐに非常に面倒なことになります。だからこそ、モデルを使用する前に可能な限りの前処理を行う必要があるのです。\n",
        "\n",
        "ここでは、（ディープニューラルネットワークは連続したデータバッチしか処理できないので）データパイプラインで画像のリサイズを行い、モデルを作成する際にモデルの一部として入力値のスケーリングを行います。\n",
        "\n",
        "画像を 150×150 にリサイズしてみましょう。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "b3678f38e087"
      },
      "outputs": [],
      "source": [
        "size = (150, 150)\n",
        "\n",
        "train_ds = train_ds.map(lambda x, y: (tf.image.resize(x, size), y))\n",
        "validation_ds = validation_ds.map(lambda x, y: (tf.image.resize(x, size), y))\n",
        "test_ds = test_ds.map(lambda x, y: (tf.image.resize(x, size), y))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "708bf9792a35"
      },
      "source": [
        "さらに、データをバッチ処理して、キャッシング＆プリフェッチを使用し、読み込み速度を最適化してみましょう。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "53ef9e6092e3"
      },
      "outputs": [],
      "source": [
        "batch_size = 32\n",
        "\n",
        "train_ds = train_ds.cache().batch(batch_size).prefetch(buffer_size=10)\n",
        "validation_ds = validation_ds.cache().batch(batch_size).prefetch(buffer_size=10)\n",
        "test_ds = test_ds.cache().batch(batch_size).prefetch(buffer_size=10)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "b60f852c462f"
      },
      "source": [
        "### ランダムデータ拡張を使用する\n",
        "\n",
        "大規模な画像データセットを持っていない場合には、ランダムに水平反転や少し回転を加えるなど、ランダムでありながら現実的な変換をトレーニング画像に適用し、サンプルの多様性を人為的に導入するのが良い実践です。これによって、過適合を遅らせると同時にトレーニングデータの異なった側面にモデルを公開することができます。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "40b1e355b9c0"
      },
      "outputs": [],
      "source": [
        "from tensorflow import keras\n",
        "from tensorflow.keras import layers\n",
        "\n",
        "data_augmentation = keras.Sequential(\n",
        "    [\n",
        "        layers.experimental.preprocessing.RandomFlip(\"horizontal\"),\n",
        "        layers.experimental.preprocessing.RandomRotation(0.1),\n",
        "    ]\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6fa8ddeda36e"
      },
      "source": [
        "さまざまなランダム変換の後、最初のバッチの最初の画像がどのように見えるかを可視化してみましょう。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9077f9fd022e"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "for images, labels in train_ds.take(1):\n",
        "    plt.figure(figsize=(10, 10))\n",
        "    first_image = images[0]\n",
        "    for i in range(9):\n",
        "        ax = plt.subplot(3, 3, i + 1)\n",
        "        augmented_image = data_augmentation(\n",
        "            tf.expand_dims(first_image, 0), training=True\n",
        "        )\n",
        "        plt.imshow(augmented_image[0].numpy().astype(\"int32\"))\n",
        "        plt.title(int(labels[0]))\n",
        "        plt.axis(\"off\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bc999e4672c3"
      },
      "source": [
        "## モデルを構築する\n",
        "\n",
        "では、先ほど説明した青写真に沿ってモデルを構築してみましょう。\n",
        "\n",
        "注意点:\n",
        "\n",
        "- `Normalization` レイヤーを追加して、入力値（最初の範囲は`[0, 255]`）を`[-1, 1]`の範囲にスケーリングします。\n",
        "- 正則化のために、分類レイヤーの前に `Dropout` レイヤーを追加します。\n",
        "- ベースモデルを呼び出す際に `training=False` を渡して推論モードで動作するようにし、ファインチューニングを実行するためにベースモデルを解凍した後でも BatchNorm の統計が更新されないようにします。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "07a2f9e9d817"
      },
      "outputs": [],
      "source": [
        "base_model = keras.applications.Xception(\n",
        "    weights=\"imagenet\",  # Load weights pre-trained on ImageNet.\n",
        "    input_shape=(150, 150, 3),\n",
        "    include_top=False,\n",
        ")  # Do not include the ImageNet classifier at the top.\n",
        "\n",
        "# Freeze the base_model\n",
        "base_model.trainable = False\n",
        "\n",
        "# Create new model on top\n",
        "inputs = keras.Input(shape=(150, 150, 3))\n",
        "x = data_augmentation(inputs)  # Apply random data augmentation\n",
        "\n",
        "# Pre-trained Xception weights requires that input be normalized\n",
        "# from (0, 255) to a range (-1., +1.), the normalization layer\n",
        "# does the following, outputs = (inputs - mean) / sqrt(var)\n",
        "norm_layer = keras.layers.experimental.preprocessing.Normalization()\n",
        "mean = np.array([127.5] * 3)\n",
        "var = mean ** 2\n",
        "# Scale inputs to [-1, +1]\n",
        "x = norm_layer(x)\n",
        "norm_layer.set_weights([mean, var])\n",
        "\n",
        "# The base model contains batchnorm layers. We want to keep them in inference mode\n",
        "# when we unfreeze the base model for fine-tuning, so we make sure that the\n",
        "# base_model is running in inference mode here.\n",
        "x = base_model(x, training=False)\n",
        "x = keras.layers.GlobalAveragePooling2D()(x)\n",
        "x = keras.layers.Dropout(0.2)(x)  # Regularize with dropout\n",
        "outputs = keras.layers.Dense(1)(x)\n",
        "model = keras.Model(inputs, outputs)\n",
        "\n",
        "model.summary()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2e8237de81e8"
      },
      "source": [
        "## トップレイヤーをトレーニングする"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9137b8daedad"
      },
      "outputs": [],
      "source": [
        "model.compile(\n",
        "    optimizer=keras.optimizers.Adam(),\n",
        "    loss=keras.losses.BinaryCrossentropy(from_logits=True),\n",
        "    metrics=[keras.metrics.BinaryAccuracy()],\n",
        ")\n",
        "\n",
        "epochs = 20\n",
        "model.fit(train_ds, epochs=epochs, validation_data=validation_ds)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "aa51d4562fa7"
      },
      "source": [
        "## モデル全体のファインチューニングを行う\n",
        "\n",
        "最後に、ベースモデルを解凍して、モデル全体のエンドツーエンドを低い学習率でトレーニングしてみましょう。\n",
        "\n",
        "重要なのは、モデル構築時の呼び出しで `training=False` を渡しているため、ベースモデルがトレーニング対象になっても、推論モードで動作しているということです。つまり、内側のバッチ正則化レイヤーのバッチ統計は更新されません。更新してしまうと、それまでにモデルが学習してきた表現が破壊されてしまいます。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3cc299505b72"
      },
      "outputs": [],
      "source": [
        "# Unfreeze the base_model. Note that it keeps running in inference mode\n",
        "# since we passed `training=False` when calling it. This means that\n",
        "# the batchnorm layers will not update their batch statistics.\n",
        "# This prevents the batchnorm layers from undoing all the training\n",
        "# we've done so far.\n",
        "base_model.trainable = True\n",
        "model.summary()\n",
        "\n",
        "model.compile(\n",
        "    optimizer=keras.optimizers.Adam(1e-5),  # Low learning rate\n",
        "    loss=keras.losses.BinaryCrossentropy(from_logits=True),\n",
        "    metrics=[keras.metrics.BinaryAccuracy()],\n",
        ")\n",
        "\n",
        "epochs = 10\n",
        "model.fit(train_ds, epochs=epochs, validation_data=validation_ds)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "afa73d989302"
      },
      "source": [
        "10エポック後、ファインチューニングによって有益な改善が得られます。"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "transfer_learning.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
